#!/usr/bin/env python
import pathlib
import pickle
import unittest
from datetime import datetime
from typing import Dict

import numpy as np
import pandas as pd
from easydict import EasyDict
from ruamel import yaml

from quant.backtest import backtest
from quant.strategy import grid_by_lh, grid_by_mid_step, grid_by_mid_step_percent, TurningPointBuy, TurningPointSell, \
    TradeOp
from quant.trader import trader, rqalpha
from quant.utils import plot_trade_signals


class StrategyTest(unittest.TestCase):
    """Base test case exposing the shared fixture directory and configs.

    Subclasses read ``testdata`` (golden pickles / reference plots) and
    ``cfgs`` (the parsed ``backtest/backtest.yml`` next to this file).
    """

    # Directory holding golden ``<name>.pkl`` files and reference ``.png`` plots.
    testdata = pathlib.Path(__file__).parent / 'testdata'

    # The YAML is parsed once, when the class is defined. UnsafeLoader can
    # build arbitrary Python objects; it is tolerable only because the file
    # is a trusted, repo-local fixture.
    # NOTE(review): newer ruamel.yaml releases deprecate this module-level
    # ``load``; ``YAML(typ='unsafe').load`` may be required — confirm against
    # the pinned version. The ``with`` target ``f`` also ends up as a
    # (closed) class attribute.
    with (pathlib.Path(__file__).parent / 'backtest' / 'backtest.yml').open(
            mode='r') as f:
        cfgs = yaml.load(f, Loader=yaml.UnsafeLoader)


class GridTest(StrategyTest):
    """Regression tests for the grid strategies and grid-construction helpers."""

    def assert_result(self, st: str, result: Dict, bt: EasyDict):
        """Compare a backtest result with its golden file and save a plot.

        Args:
            st: Strategy name; selects ``testdata/<st>.pkl`` (expected
                result) and ``testdata/<st>.png`` (reference plot).
            result: Value returned by ``backtest.run_strategy``; its
                ``'signals'`` entry must expose a ``signal`` column.
            bt: Backtest config; ``bt.vars.frequency`` and ``bt.vars.grid``
                are forwarded to ``plot_trade_signals``.
        """
        golden = self.testdata.joinpath('%s.pkl' % st)
        # NOTE(review): when the golden pickle is missing nothing is compared
        # (and none is written here), so the check passes vacuously.
        if golden.exists():
            with golden.open(mode='rb') as fp:
                expected = pickle.load(fp)
            np.testing.assert_array_equal(expected['signals'].signal,
                                          result['signals'].signal)
        plot = self.testdata.joinpath('%s.png' % st)
        if not plot.exists():
            # The reference chart is rendered once, for visual inspection only.
            # NOTE(review): the returned figure is never closed; if it is a
            # pyplot figure, open figures accumulate across tests — confirm.
            fig = plot_trade_signals(result['signals'].reset_index(),
                                     'datetime', 'open', bt.vars.frequency,
                                     bt.vars.grid)
            fig.savefig(plot)

    def _run_grid_backtest(self, cfg_name: str, start_date: str,
                           end_date: str, names: list):
        """Backtest each strategy in ``names`` from config ``cfg_name``.

        Every strategy runs on the same 26..40 grid (``grid_by_lh(26, 40,
        10)``) between ``start_date`` and ``end_date``. Each task result is
        checked with ``assert_result`` under a per-strategy ``subTest``, so a
        failure names the strategy and does not hide later ones.
        """
        bt = EasyDict(self.cfgs[cfg_name])
        bt.vars.update({
            'start_date': start_date,
            'end_date': end_date,
            'grid': grid_by_lh(26, 40, 10),
        })
        strategies = bt.strategies
        for name in names:
            with self.subTest(strategy=name):
                # Restrict the config to a single strategy for this run.
                bt.strategies = {name: strategies[name]}
                bt.backtest_dir = self.testdata
                for task in backtest.make_tasks(bt):
                    task.set_v(0)
                    result = backtest.run_strategy(pathlib.Path('/tmp'), task,
                                                   bt.config)
                    self.assert_result(name, result, bt)

    def test_grid_dfcf_1d(self):
        self._run_grid_backtest('grid_dfcf_1d', '2021-01-01', '2022-01-01',
                                ['basic', 'enhanced'])

    def test_grid_dfcf_1m(self):
        self._run_grid_backtest('grid_dfcf_1m', '2022-01-01', '2022-02-01',
                                ['dtp'])

    @staticmethod
    def test_grid_by_fn():
        """Grid helpers produce the expected price levels."""
        actual = grid_by_mid_step(5, 1, 10)
        np.testing.assert_array_equal(
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], actual)

        actual = grid_by_mid_step_percent(5, 0.2, 10)
        np.testing.assert_array_almost_equal(
            [2.048, 2.56, 3.2, 4.0, 5.0, 6.0, 7.2, 8.64, 10.368, 12.4416],
            actual)

        actual = grid_by_lh(1, 10, 10)
        np.testing.assert_array_equal(
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], actual)

    def test_trading_date_in_n(self):
        """``trading_date_in_n`` offsets a date by ``n`` trading days.

        The expectations encode that 2022-04-02 through 2022-04-05 are not
        trading days, and that a non-trading date with ``n == 0`` maps back
        to the previous trading day.
        """
        cases = [
            ('2022-04-01', 0, '2022-04-01'),
            ('2022-04-01', 1, '2022-04-06'),
            ('2022-04-02', 0, '2022-04-01'),
            ('2022-04-02', 1, '2022-04-06'),
            ('2022-04-02', -1, '2022-04-01'),
            ('2022-04-06', -1, '2022-04-01'),
        ]
        for date, n, expected in cases:
            with self.subTest(date=date, n=n):
                self.assertEqual(trader.Trader.trading_date_in_n(date, n),
                                 expected)


class ConbondTest(StrategyTest):
    """Regression tests for the convertible-bond selection strategies."""

    def conbond_test(self, name: str, bt_vars: Dict, params_vars: Dict):
        """Backtest strategy ``name`` from the ``conbond`` config.

        ``bt_vars`` is merged into the config's ``vars``, and ``params_vars``
        replaces the strategy's ``params_vars``. After each generated task
        runs, its selected ``order_book_id`` column is compared with
        ``testdata/<name>.pkl`` if that golden file exists. Otherwise the
        task result is left unchecked.
        """
        cfg = EasyDict(self.cfgs['conbond'])
        cfg.vars.update(bt_vars)
        all_strategies = cfg.strategies
        cfg.strategies = {name: all_strategies[name]}
        cfg.strategies[name].params_vars = params_vars
        cfg.backtest_dir = self.testdata
        out_dir = pathlib.Path('/tmp')
        for task in backtest.make_tasks(cfg):
            task.set_v(0)
            outcome = backtest.run_strategy(out_dir, task, cfg.config)
            # The golden file is looked up after every run, as before.
            golden = self.testdata.joinpath('%s.pkl' % name)
            if not golden.exists():
                continue
            with golden.open(mode='rb') as fp:
                reference = pickle.load(fp)
            pd.testing.assert_frame_equal(
                reference['signals'].reset_index()[['order_book_id']],
                outcome['signals'].reset_index()[['order_book_id']])

    def test_double_low(self):
        bt_vars = {
            'start_date': '2021-06-01',
            'end_date': '2021-07-01',
            'name': 'double_low',
        }
        params_vars = {
            'threshold': [10000],
            'rebalance': ['周'],
            'days_to_maturity': [7],
            'days_to_stop_trading': [7],
            'top': [20],
        }
        self.conbond_test('double_low', bt_vars, params_vars)

    def test_double_low_rank(self):
        bt_vars = {
            'start_date': '2021-06-01',
            'end_date': '2021-07-01',
            'name': 'double_low_rank',
        }
        params_vars = {
            'threshold': [10000],
            'rebalance': ['周'],
            'days_to_maturity': [7],
            'days_to_stop_trading': [7],
            'top': [20],
            'weight': [{
                'weight_price': 0.5,
                'weight_cpr': 0.5,
            }],
        }
        self.conbond_test('double_low_rank', bt_vars, params_vars)


class ConditionTradeTest(unittest.TestCase):
    def test_turning_point_buy(self):
        st = TurningPointBuy(threshold=1)
        testdata = pd.DataFrame.from_dict(
            {'open': [10, 9.6, 9.3, 9.5, 9.7, 9.2, 9.8, 10.3, 10.4, 9.9]})
        testdata['datetime'] = pd.date_range(start=datetime.now(),
                                             periods=len(testdata),
                                             freq='m')
        expected = [
            TradeOp.SKIP_BUY, TradeOp.SKIP_BUY, TradeOp.SKIP_BUY,
            TradeOp.SKIP_BUY, TradeOp.SKIP_BUY, TradeOp.SKIP_BUY,
            TradeOp.SKIP_BUY, TradeOp.BUY, TradeOp.BUY, TradeOp.SKIP_BUY
        ]
        for idx in testdata.index:
            dk = testdata.loc[idx]
            actual = st.generate_signal(pd.DataFrame([dk]),
                                        datetime.now()).iloc[0]
            self.assertEqual(actual.signal, expected[idx])

    def test_turning_point_sell(self):
        st = TurningPointSell(order_book_id='foo',
                              dt=datetime.now(), base=9.5, threshold=0.1)
        testdata = pd.DataFrame.from_dict(
            {'open': [10, 9.6, 10.3, 9.9, 10.6, 10, 9.8, 9.5, 9.3, 10]})
        testdata['datetime'] = pd.date_range(start=datetime.now(),
                                             periods=len(testdata),
                                             freq='m')
        expected = [
            TradeOp.SKIP_SELL, TradeOp.SKIP_SELL, TradeOp.SKIP_SELL,
            TradeOp.SKIP_SELL, TradeOp.SKIP_SELL, TradeOp.SKIP_SELL,
            TradeOp.SKIP_SELL, TradeOp.SELL, TradeOp.SELL, TradeOp.SKIP_SELL
        ]
        for idx in testdata.index:
            dk = testdata.loc[idx]
            actual = st.generate_signal(pd.DataFrame([dk]),
                                        datetime.now()).iloc[0]
            self.assertEqual(actual.signal, expected[idx])


if __name__ == '__main__':
    # Run every TestCase in this module when executed directly as a script.
    unittest.main()
